import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn 
import os.path as osp  
import numpy as np  
import random

import os
import argparse

from building import *
from utils import progress_bar
from accuracy import Acc

def print_args(args):
    """Render every attribute of ``args`` as ``name:value`` lines under a banner.

    Args:
        args: any object with a ``__dict__`` (here, the parsed argparse namespace).

    Returns:
        The banner line followed by one ``name:value`` line per attribute, in
        attribute insertion order; every line ends with a newline.
    """
    banner = "==========================================\n"
    body = ["{}:{}\n".format(name, value) for name, value in args.__dict__.items()]
    return banner + "".join(body)


def one_hot(preds, num_classes=None):
    """Return a float one-hot matrix for a 1-D tensor of integer class indices.

    Args:
        preds: 1-D integer tensor of class indices.
        num_classes: width of the output. Defaults to ``preds.max() + 1``
            (0 for an empty ``preds``). Pass it explicitly to get a fixed width
            regardless of which classes occur in ``preds``.

    Returns:
        A ``(len(preds), num_classes)`` float tensor on the CPU with a single
        1 per row, in the column given by the corresponding entry of ``preds``.
    """
    # Work on a CPU long copy: the output is allocated on the CPU, and
    # advanced indexing requires integer (long) indices.
    idx = preds.long().cpu()
    if num_classes is None:
        # .max() raises on an empty tensor, so handle that case explicitly.
        num_classes = int(idx.max()) + 1 if idx.numel() else 0
    oh = torch.zeros(idx.size(0), num_classes)
    # Vectorised scatter: row i gets a 1 at column idx[i].
    oh[torch.arange(idx.size(0)), idx] = 1
    return oh

# Command-line interface. There is intentionally no --net flag: the main loop
# iterates over a fixed list of architectures and sets args.net itself.
parser = argparse.ArgumentParser(
    description='Score a zoo of pretrained models on a target set and '
                'correlate the scores with reference accuracies')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--bs', default=64, type=int, help='batch size')
parser.add_argument('--input_size','-i', default=224, type=int, help='input image size')
parser.add_argument('--crop_size', default=256, type=int, help='crop image size')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--gpu', '-g', default='0', type=str, help='gpu id')
parser.add_argument('--workers','-w', default=4, type=int)

parser.add_argument('--tgt', default='cifar100', type=str, help='target set')
parser.add_argument('--src', default='imagenet', type=str, help='source set')
parser.add_argument('--ssl', default=None, type=str, help='source ssl method')
parser.add_argument('--metric', default='FaCe', type=str, help='estimation metric')
parser.add_argument('--norm', default=None, type=str)
parser.add_argument('--use_pred', action='store_true')
parser.add_argument('--tune_head', action='store_true')
parser.add_argument('--use_test', action='store_true')
parser.add_argument('--nc', action='store_true')
parser.add_argument('--tsne', action='store_true')
parser.add_argument('--tuning_mode', default=1, type=int)
parser.add_argument('--pca', default=0, type=int, help='PCA setting (also recorded in the log filename)')
parser.add_argument('--log', default='singlezoo', type=str, help='prefix of the log filename')
parser.add_argument('--T', default=0.05, type=float)
args = parser.parse_args()

# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
# initialised, so it is set before any model or tensor touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# Seed every RNG source used by the script for reproducible runs.
SEED = args.seed
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# --nc runs get their own log directory.
log_dir = './TE_log/singlezoo_ft_nc' if args.nc else './TE_log/singlezoo'
# exist_ok avoids the check-then-create race of isdir() followed by makedirs().
os.makedirs(log_dir, exist_ok=True)
# One log file per (log prefix, src, tgt, metric, pca, use_pred, T) setting.
# It starts with a dump of all arguments and is flushed after each write.
args.out_file = open(osp.join(log_dir, f'{args.log}_{args.src}-to-{args.tgt}_{args.metric}_{args.pca}_{str(args.use_pred)}_T_{args.T}.txt'), 'w')
args.out_file.write(print_args(args)+'\n')
args.out_file.flush()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy (not read anywhere in this script)
start_epoch = 0  # start from epoch 0 or last checkpoint epoch (not read anywhere in this script)

print('==> Preparing data..')
trainset, trainloader, testset, testloader = build_tgt_datasets(args)

# Architectures evaluated by the main loop; scores are appended to te_scores
# in this order.
nets = ['resnet50', 'resnet101', 'resnet152', 'densenet121', 'densenet169', 'densenet201', 'mobilenetv1', 'mobilenetv2', 'mobilenetv3_large', 'efficientnetb0', 'efficientnetb1', 'efficientnetb2', 'efficientnetb3', 'vgg16', 'vgg19']

# Reference accuracies for the target set.
# NOTE(review): presumably ordered like `nets` -- TODO confirm against Acc.
gt = Acc(args.tgt).acc
if args.metric == 'bn':
    # The 'bn' metric skips the VGG models, which are the last two entries of
    # `nets`. Drop their reference values so scores and gt stay aligned.
    # Keep this in sync if `nets` is reordered.
    gt = gt[:-2]
te_scores = []

args.num_cls_src = cls_num_query(args.src)
args.num_cls_tgt = cls_num_query(args.tgt)

def _replace_head_for_target(net, net_name, num_classes):
    """Replace the classification head of ``net`` with a fresh ``num_classes``-way head.

    Used before strictly loading a checkpoint from ``checkpoint/<src>2<tgt>/``
    so the head's shape matches the stored weights. Feature widths are
    hard-coded per architecture. Names matching no branch are left untouched.

    Args:
        net: model returned by ``build_net_imgnet_pretrain`` (modified in place).
        net_name: architecture name, e.g. an entry of ``nets``.
        num_classes: number of output classes of the new head.
    """
    if net_name.startswith('resnet'):
        net.fc = torch.nn.Linear(2048, num_classes)
    elif net_name.startswith('densenet'):
        in_dim = {'densenet121': 1024, 'densenet201': 1920}.get(net_name, 1664)
        net.classifier = torch.nn.Linear(in_dim, num_classes)
    elif net_name.startswith('efficientnet'):
        in_dim = {'efficientnetb2': 1408, 'efficientnetb3': 1536}.get(net_name, 1280)
        net.classifier = nn.Sequential(
                nn.Dropout(p=0.2, inplace=True),
                nn.Linear(in_dim, num_classes),
            )
    elif net_name.startswith('monilenet'):
        # NOTE(review): 'monilenet' looks like a typo for 'mobilenet', so this
        # branch never fires for the mobilenet* entries in `nets`. It is left
        # unchanged because "fixing" it would attach a new `output` Linear with
        # 2048 input features, which a strict load would then require in the
        # checkpoint. Confirm against the script that wrote the checkpoints.
        net.output = torch.nn.Linear(2048, num_classes)
    elif net_name.startswith('vgg'):
        net.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(p=0.5),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(p=0.5),
                nn.Linear(4096, num_classes),
            )


# Remember the CLI learning rate so the VGG override below does not leak into
# models evaluated after a VGG.
base_lr = args.lr
for nw in nets:
    if args.metric == 'bn' and nw.startswith('vgg'):
        continue
    args.net = nw
    args.lr = 1e-3 if nw.startswith('vgg') else base_lr

    ckpt_src_dir = 'checkpoint/src/' + args.src + '/' + args.net
    if args.nc:
        ckpt_src_dir = 'checkpoint/' + args.src + '2' + args.tgt + '/' + args.net

    # Model
    print('==> Building model..')

    net, ckpt_src = build_net_imgnet_pretrain(args, ckpt_src_dir)
    if args.nc:
        ckpt_src = ckpt_src_dir + '/ckpt.pth'

    print('==> Resuming from source checkpoint..')

    # Checkpoints are loaded onto the CPU and the model is moved to `device`
    # afterwards. This keeps GPU-saved checkpoints usable on the CPU fallback.
    if args.src != 'imagenet':
        ckpt = torch.load(ckpt_src, map_location='cpu')
        # NOTE(review): assumes a ResNet-style `fc` head with 2048 input
        # features; other architectures in `nets` would need their own head.
        net.fc = torch.nn.Linear(2048, args.num_cls_src)
        net.load_state_dict(ckpt['net'], strict=True)
    elif args.nc:
        ckpt = torch.load(ckpt_src, map_location='cpu')
        # The checkpoint's head has args.num_cls_tgt outputs; match it before
        # the strict load.
        _replace_head_for_target(net, args.net, args.num_cls_tgt)
        net.load_state_dict(ckpt['net'], strict=True)
    elif not args.net.startswith('densenet'):
        # DenseNet weights are already loaded in build_net_imgnet_pretrain.
        ckpt = torch.load(ckpt_src, map_location='cpu')
        net.load_state_dict(ckpt)

    net = net.to(device)

    forward_func = build_forward_func(args)
    Estimate = build_estimater(args)

    # Switch to inference mode before any feature extraction, including t-SNE,
    # so dropout is disabled and BatchNorm uses its running statistics.
    net.eval()

    if args.tsne:
        tsne(args, trainloader, net, forward_func, save_dir = f'singlezoo/{args.tgt}', save_name = f'{args.net}_{args.tgt}.png')
        continue

    score = Estimate(trainloader, net, args, forward_func)

    te_scores.append(score)

    log_str = f'\n[{args.src} to {args.tgt}]/[{args.net}] {args.metric} Score: {score}'

    print(log_str)
    print()
    torch.cuda.empty_cache()
    args.out_file.write(log_str+'\n')
    args.out_file.flush()

import pandas as pd

print(te_scores)
try:
    # In --tsne mode the loop only saves plots and never appends a score, so
    # there is nothing to correlate against the reference accuracies.
    if te_scores:
        correlation(te_scores, gt, args)
    else:
        print('No scores were computed; skipping correlation.')
finally:
    # Closed only after correlation(), which receives args (and so the log
    # file handle).
    args.out_file.close()
